Argmax

沿指定轴查找最大的 topk 个值的索引。当 topk=1 时,该算子等价于 ArgMax。

\[Y_i = \underset{k}{\operatorname{argmax}} (X_{slice_i})\]

其中 \(X_{slice_i}\) 是输入张量中沿指定轴的一个切片,函数返回该切片中最大值的索引 \(k\)。

输入:
  • input - 输入数据地址。

  • params - 其他参数打包成数组。

  • core_mask - 核掩码。

输出:
  • output - 存储索引的输出张量。

  • output_value - 如果 out_value(params[6])为非 0,则此处存储找到的值。

支持平台:

FT78NE MT7004

备注:

  • FT78NE 支持fp32

  • MT7004 支持fp16, fp32

参数数组结构:

 1long long params[9];
 2params[0] = (long long)in_shape; // 输入张量的维度信息数组。
 3params[1] = (long long)in_strides; // 输入张量的步长信息数组。
 4params[2] = (long long)out_strides; // 输出张量的步长信息数组。
 5params[3] = (long long)arg_elements; // 用于存放候选值的临时工作空间地址。
 6params[4] = (long long)index; // 用于存放候选索引的临时工作空间地址。
 7params[5] = (long long)topk; // 需要查找的最大值的数量。
 8params[6] = (long long)out_value; // 是否返回数值的标志。若为非0,则 `output_value` 必须提供有效地址。
 9params[7] = (long long)in_shape_size; // 输入张量的维度数 (即 `in_shape` 数组的长度)。
10params[8] = (long long)axis; // 执行查找操作的轴。

共享存储版本:

void fp_arg_max_s(float *input, void *output, float *output_value, long long *params, int core_mask)
void hp_arg_max_s(half *input, void *output, half *output_value, long long *params, int core_mask)

C调用示例:

 1//FT78NE示例
 2#include <stdio.h>
 3#include <argmax.h>
 4int main(int argc, char* argv[]) {
 5    float* input = (float*)0x81000000; //需要初始化
 6    void* output = (void*)0x82000000; //不需要初始化
 7    float* output_value = (float*)0x83000000; //可选
 8    float* arg_elements = (float*)0x84000000;//不需要初始化
 9    int* index = (int*)0x85000000; //不需要初始化
10
11    int core_mask = 0b1111;
12
13    int *in_strides = (int*)0x86000000;
14    int *out_strides = (int*)0x86000200;
15
16    int in_shape_size = 4;//最多只考虑4维
17    int in_shape[4] = {4, 8, 16, 8};
18
19    int axis = 1; //要操作的维度,不能超过3
20    int topk = 3; //不超过in_shape[axis]
21
22    int out_value = 0; //是否输出值,1表示输出值,0表示输出索引
23    float *outputfp32 = (float *)output;
24    int *outputint = (int *)output;
25
26    srand(time(0));
27    // 初始化测试数据,包含各种情况
28    int i, j;
29    //计算输入张量的元素总数
30    int in_total_elements = in_shape[0] * in_shape[1] * in_shape[2] * in_shape[3];
31    for(i = 0; i < in_total_elements; i ++) {
32        input[i] = (float)(rand()%100);
33    }
34
35    long long params[9];
36    params[0] = (long long)in_shape;
37    params[1] = (long long)in_strides;
38    params[2] = (long long)out_strides;
39    params[3] = (long long)arg_elements;
40    params[4] = (long long)index;
41    params[5] = (long long)topk;
42    params[6] = (long long)out_value;
43    params[7] = (long long)in_shape_size;
44    params[8] = (long long)axis;
45
46    fp_arg_max_s(input, output, output_value, params, core_mask);
47    return 0;
48}

私有存储版本:

void fp_arg_max_p(float *input, void *output, float *output_value, long long *params)
void hp_arg_max_p(half *input, void *output, half *output_value, long long *params)

C调用示例:

 1//FT78NE示例
 2#include <stdio.h>
 3#include <argmax.h>
 4int main(int argc, char* argv[]) {
 5    float* input = (float*)0x10010000; //需要初始化
 6    void* output = (void*)0x10020000; //不需要初始化
 7    float* output_value = (float*)0x10030000; //可选
 8    float* arg_elements = (float*)0x10040000;//不需要初始化
 9    int* index = (int*)0x10050000; //不需要初始化
10
11    int *in_strides = (int*)0x1004E000;
12    int *out_strides = (int*)0x1004E200;
13
14    int in_shape_size = 4;//最多只考虑4维
15    int in_shape[4] = {4, 8, 16, 8};
16
17    int axis = 1; //要操作的维度,不能超过3
18    int topk = 3; //不超过in_shape[axis]
19
20    int out_value = 0; //是否输出值,1表示输出值,0表示输出索引
21    float *outputfp32 = (float *)output;
22    int *outputint = (int *)output;
23
24    srand(time(0));
25    // 初始化测试数据,包含各种情况
26    int i, j;
27    //计算输入张量的元素总数
28    int in_total_elements = in_shape[0] * in_shape[1] * in_shape[2] * in_shape[3];
29    for(i = 0; i < in_total_elements; i ++) {
30        input[i] = (float)(rand()%100);
31    }
32
33    long long params[9];
34    params[0] = (long long)in_shape;
35    params[1] = (long long)in_strides;
36    params[2] = (long long)out_strides;
37    params[3] = (long long)arg_elements;
38    params[4] = (long long)index;
39    params[5] = (long long)topk;
40    params[6] = (long long)out_value;
41    params[7] = (long long)in_shape_size;
42    params[8] = (long long)axis;
43
44    fp_arg_max_p(input, output, output_value, params);
45    return 0;
46}